# -*- coding: utf-8 -*-
"""
Created on Wed Mar  4 09:42:35 2020

@author: Shichun Hu

Content: Functions for generating demand
"""
import numpy as np
from itertools import product
#from gen_helper import *
#from scipy.stats import truncnorm
#import matplotlib.pyplot as plt


class Passenger:
    '''
    Record holding the fixed trip data and the running ride state of one passenger.

    Parameters
    ----------
    ID : int
        passenger index, starting from 0
    alpha : float
        trip length of the passenger (the generators in this file pass the
        straight-line O-D distance)
    O, D : vector
        origin / destination coordinates
    int_time : float
        value stored as-is; presumably the inter-arrival time - confirm with caller
    cfg : Config object
        only cfg.speed is read here
    '''
    def __init__(self, ID, alpha, O, D, int_time, cfg):
        # travel time of the trip when driven at the configured speed
        direct_time = alpha/cfg.speed

        # --- fixed trip description ---
        self.ID = ID
        self.alpha = alpha
        self.O = O
        self.D = D
        self.int_time = int_time
        # --- ride state, all start "not riding" ---
        self.IVT = 0            # min, in vehicle time so far
        self.IVM = 0.0          # mile, in vehicle mile so far
        self.onBoard = False    # True while the passenger is in a vehicle
        # --- service limits derived from the trip ---
        self.dir_time = direct_time
        self.MIVT = 1.5*direct_time   # min, maximum IVT = 1.5 x direct time
        self.W = 2.0*alpha            # willingness to pay level = 2 x alpha
        # [v.ID, seq_idx] once assigned; an empty list means not served
        self.v_info = []

    def update(self, cur_IVT):
        '''Overwrite the in vehicle time with cur_IVT.'''
        self.IVT = cur_IVT


class Vehicle:
    '''
    Record holding the fixed trip data and the operating status of one driver.

    Parameters
    ----------
    ID : int
        vehicle index, starting from 0
    F : float
        trip length of the driver (the generators in this file pass the
        straight-line O-D distance)
    O, D : vector
        origin / destination coordinates of the driver
    cfg : Config object
        cfg.speed, cfg.cap, cfg.beta_info and cfg.eligible are read here
    '''
    def __init__(self, ID, F, O, D, cfg):
        detour_factor = 1.5                     # max total miles = factor x F
        mile_budget = F*detour_factor
        time_budget = mile_budget/cfg.speed     # same budget expressed in time

        # --- fixed description ---
        self.ID = ID
        self.F = F
        self.O = O
        self.D = D
        self.c = detour_factor
        self.t_lim = time_budget
        self.cap = cfg.cap
        self.beta_info = cfg.beta_info
        self.eligible = cfg.eligible
        self.base_speed = cfg.speed
        # --- operating status ---
        self.loc = self.O               # current location, starts at the origin
        self.lim_left = mile_budget     # miles left
        self.time_left = time_budget    # time left
        self.unserved = []              # nodes still to be visited
        self.miles_done = 0             # miles traveled
        self.time_done = 0              # time spent
        self.pas_seq = []               # served passengers (pas objects)
        self.ld_hp = 0                  # current load of the vehicle
        self.curSpeed = cfg.speed       # current speed, changed by updateSpeed
        self.prev_node_ID = ID          # last node index, starts at the vehicle's own ID

    def updateSpeed(self):
        '''
        Refresh curSpeed for the segment prev_node_ID -> unserved[0].

        With no unserved node the speed is left untouched. Otherwise, if the
        load plus one reaches the segment's eligible number, the
        speed becomes base_speed divided by the segment's beta; if not, it is
        reset to base_speed.
        '''
        if len(self.unserved) == 0:
            return

        here, nxt = self.prev_node_ID, self.unserved[0]
        if self.ld_hp + 1 >= self.eligible[here][nxt]:
            self.curSpeed = self.base_speed / self.beta_info[here][nxt]
        else:
            self.curSpeed = self.base_speed
        
        
def betas(n, m, percent_list, beta_control):
    '''
    Randomly assign an "eligible" occupancy level and a beta factor to every
    node pair.

    Inputs:
        n, m: the matrices are built for 2*(n+m) nodes
        percent_list: length 3, probabilities of drawing "eligible" num 1, 2, 3
            respectively (sum(percent_list) = 1.0)
        beta_control: length 3, beta value attached to "eligible" num 1, 2, 3
    Outputs:
        beta_info[i][j] and eligible[i][j] hold (beta, eligible) for the
        segment from node i to node j; both matrices are symmetric and keep
        1 on the diagonal.
    Note:
        one number is drawn for every ordered pair (i, j), i != j, and written
        to both [i][j] and [j][i]; the pair is therefore written twice and the
        value of the later visit is the one that remains.
    '''
    node_num = 2*(n+m)
    beta_info = np.ones((node_num, node_num))
    eligible = np.ones((node_num, node_num), dtype=int)

    # a single batch of draws, consumed in row-major order of the off-diagonal cells
    numbers = np.random.choice([1,2,3], size=(node_num)*(node_num-1), p=percent_list)
    draws = iter(numbers)

    for i, j in product(range(node_num), repeat=2):
        if i == j:
            continue
        level = next(draws)
        # level k (1-based) selects beta_control[k-1]
        boost = beta_control[level - 1]
        eligible[i][j] = eligible[j][i] = level
        beta_info[i][j] = beta_info[j][i] = boost

    return beta_info, eligible


def dist_cal(O, D):
    '''
    Euclidean distance between the points O and D.

    Parameters
    ----------
    O, D : array-like
        coordinates of the two points (ndarray, list or tuple of equal length)

    Returns
    -------
    dist : float
        length of the vector D - O
    '''
    # asarray lets plain lists/tuples through (list - list is a TypeError);
    # ndarrays are used without copying
    offset = np.asarray(D) - np.asarray(O)
    return np.linalg.norm(offset)
        
        
def small(n,cfg):
    '''
    Generate a random instance with one driver and n passengers.

    Parameters
    ----------
    n : int
        # of passengers in the small case scenario (4)
    cfg : Config object
        a configuration containing all parameters; cfg.pat, cfg.size and
        (for the 'cluster' pattern) cfg.r_cluster are read here

    Returns
    -------
    OD_mat : 2*(n+1) x 2 array
        an array storing all the ODs (row 0 driver's O, row 1 driver's D,
        then n passenger Os, then n passenger Ds)
        x | y
    det_info : length-n vector
        for passenger i: dist(O_i, O_driver) + alpha_i + dist(D_i, D_driver) - F
    alpha_info : length-n vector
        a vector storing the alpha values (O-D distance) for passengers
    F : float
        O-D distance of the driver

    Raises
    ------
    ValueError
        if cfg.pat is neither 'random' nor 'cluster'
    '''
    # Draw order (all origins first, then all destinations) is kept so that a
    # seeded run produces the same points as before.
    if cfg.pat == 'random':
        # uniform over the rectangle [0, size[0]] x [0, size[1]]
        extent = np.array([cfg.size[0], cfg.size[1]])
        O_mat = np.random.rand(n+1, 2) * extent
        D_mat = np.random.rand(n+1, 2) * extent
    elif cfg.pat == 'cluster':
        # origins in the r_cluster square at (0, 0), destinations in the
        # r_cluster square pushed against the (size[0], size[1]) corner
        shift = np.array([cfg.size[0]-cfg.r_cluster, cfg.size[1]-cfg.r_cluster])
        O_mat = np.random.rand(n+1, 2) * cfg.r_cluster
        D_mat = np.random.rand(n+1, 2) * cfg.r_cluster + shift
    else:
        raise ValueError("unknown demand pattern cfg.pat=%r; expected 'random' or 'cluster'" % (cfg.pat,))

    alpha_all = np.linalg.norm(D_mat - O_mat, axis=1)
    F = alpha_all[0]            # row 0 is the driver
    alpha_info = alpha_all[1:]

    # driver -> pickup and dropoff -> driver destination legs, all passengers at once
    to_pickup = np.linalg.norm(O_mat[1:] - O_mat[0], axis=1)
    from_dropoff = np.linalg.norm(D_mat[1:] - D_mat[0], axis=1)
    det_info = to_pickup + alpha_info + from_dropoff - F

    OD_mat = np.concatenate((O_mat[:1], D_mat[:1], O_mat[1:], D_mat[1:]))

    return OD_mat, det_info, alpha_info, F


def large(n,m,cfg):
    '''
    Generate a random instance with n passengers and m vehicles using the
    global numpy random state.

    Parameters
    ----------
    n : int
        # of passengers in large case
    m : int
        # of vehicles in large case
    cfg : Config object
        a configuration containing all parameters; cfg.pat, cfg.size,
        cfg.mu_t, cfg.rad_max and (for the 'cluster' pattern) cfg.r_cluster
        are read here, and cfg is forwarded to Passenger / Vehicle

    Returns
    -------
    OD_mat : 2*(n+m) x 2 array
        an array storing all the ODs (m vehicle Os, m vehicle Ds, n passenger
        Os, n passenger Ds)
        x | y
    radius_list : list of length 2*(n+m)
        one radius per row of OD_mat: 0 for every vehicle row, and for
        passenger i the same uniform draw in [0, cfg.rad_max) on both its O
        row and its D row
    int_time_info : length-(n+1) vector
        n exponential draws with scale cfg.mu_t followed by a trailing 1000
        (presumably an end-of-stream sentinel - confirm with caller)
    passengers : list of Passenger
    vehicles : list of Vehicle

    Raises
    ------
    ValueError
        if cfg.pat is neither 'random' nor 'cluster'
    '''
    total = n + m

    # Draw order (origins, destinations, inter-arrival times, radii) is kept
    # so that a seeded run produces the same instance as before.
    if cfg.pat == 'random':
        # uniform over the rectangle [0, size[0]] x [0, size[1]]
        extent = np.array([cfg.size[0], cfg.size[1]])
        O_mat = np.random.rand(total, 2) * extent
        D_mat = np.random.rand(total, 2) * extent
    elif cfg.pat == 'cluster':
        # origins in the r_cluster square at (0, 0), destinations in the
        # r_cluster square pushed against the (size[0], size[1]) corner
        shift = np.array([cfg.size[0]-cfg.r_cluster, cfg.size[1]-cfg.r_cluster])
        O_mat = np.random.rand(total, 2) * cfg.r_cluster
        D_mat = np.random.rand(total, 2) * cfg.r_cluster + shift
    else:
        raise ValueError("unknown demand pattern cfg.pat=%r; expected 'random' or 'cluster'" % (cfg.pat,))

    # rows [0, m) are vehicles, rows [m, m+n) are passengers
    alpha_all = np.linalg.norm(D_mat - O_mat, axis=1)
    F = alpha_all[:m]
    alpha_info = alpha_all[m:]

    OD_mat = np.concatenate((O_mat[:m], D_mat[:m], O_mat[m:], D_mat[m:]))

    int_time_info = np.random.exponential(scale=cfg.mu_t, size=n)
    int_time_info = np.append(int_time_info, 1000)

    # radius list following the indices in OD_mat: vehicles get 0, each
    # passenger radius is repeated for its origin row and its destination row
    rad_p = (np.random.rand(n)*cfg.rad_max).tolist()
    # rad_p = [cfg.rad_max]*n
    radius_list = [0]*(2*m) + 2*rad_p

    # passenger and vehicle instances
    passengers = [
        Passenger(i, alpha_info[i], O_mat[i+m], D_mat[i+m], int_time_info[i], cfg)
        for i in range(n)
    ]
    vehicles = [Vehicle(j, F[j], O_mat[j], D_mat[j], cfg) for j in range(m)]

    return OD_mat, radius_list, int_time_info, passengers, vehicles


def large_parallel(n,m,cfg,rng):
    '''
    Generate a random instance with n passengers and m vehicles, drawing every
    random number from the supplied generator instead of the global state.

    Parameters
    ----------
    n : int
        # of passengers in large case
    m : int
        # of vehicles in large case
    cfg : Config object
        a configuration containing all parameters; cfg.pat, cfg.size,
        cfg.mu_t, cfg.rad_max and (for the 'cluster' pattern) cfg.r_cluster
        are read here, and cfg is forwarded to Passenger / Vehicle
    rng : random generator
        must provide rand(...) and exponential(...), e.g. a
        numpy.random.RandomState

    Returns
    -------
    OD_mat : 2*(n+m) x 2 array
        an array storing all the ODs (m vehicle Os, m vehicle Ds, n passenger
        Os, n passenger Ds)
        x | y
    radius_list : list of length 2*(n+m)
        one radius per row of OD_mat: 0 for every vehicle row, and for
        passenger i the same uniform draw in [0, cfg.rad_max) on both its O
        row and its D row
    int_time_info : length-(n+1) vector
        n exponential draws with scale cfg.mu_t followed by a trailing 1000
        (presumably an end-of-stream sentinel - confirm with caller)
    passengers : list of Passenger
    vehicles : list of Vehicle

    Raises
    ------
    ValueError
        if cfg.pat is neither 'random' nor 'cluster'
    '''
    total = n + m

    # Draw order (origins, destinations, inter-arrival times, radii) is kept
    # so that the same rng state produces the same instance as before.
    if cfg.pat == 'random':
        # uniform over the rectangle [0, size[0]] x [0, size[1]]
        extent = np.array([cfg.size[0], cfg.size[1]])
        O_mat = rng.rand(total, 2) * extent
        D_mat = rng.rand(total, 2) * extent
    elif cfg.pat == 'cluster':
        # origins in the r_cluster square at (0, 0), destinations in the
        # r_cluster square pushed against the (size[0], size[1]) corner
        shift = np.array([cfg.size[0]-cfg.r_cluster, cfg.size[1]-cfg.r_cluster])
        O_mat = rng.rand(total, 2) * cfg.r_cluster
        D_mat = rng.rand(total, 2) * cfg.r_cluster + shift
    else:
        raise ValueError("unknown demand pattern cfg.pat=%r; expected 'random' or 'cluster'" % (cfg.pat,))

    # rows [0, m) are vehicles, rows [m, m+n) are passengers
    alpha_all = np.linalg.norm(D_mat - O_mat, axis=1)
    F = alpha_all[:m]
    alpha_info = alpha_all[m:]

    OD_mat = np.concatenate((O_mat[:m], D_mat[:m], O_mat[m:], D_mat[m:]))

    int_time_info = rng.exponential(scale=cfg.mu_t, size=n)
    int_time_info = np.append(int_time_info, 1000)

    # radius list following the indices in OD_mat: vehicles get 0, each
    # passenger radius is repeated for its origin row and its destination row
    rad_p = (rng.rand(n)*cfg.rad_max).tolist()
    radius_list = [0]*(2*m) + 2*rad_p

    # passenger and vehicle instances
    passengers = [
        Passenger(i, alpha_info[i], O_mat[i+m], D_mat[i+m], int_time_info[i], cfg)
        for i in range(n)
    ]
    vehicles = [Vehicle(j, F[j], O_mat[j], D_mat[j], cfg) for j in range(m)]

    return OD_mat, radius_list, int_time_info, passengers, vehicles


def LA_data(n,m, OD_prob, LA_loc, radius):
    '''
    Sample ODs for m vehicles and n passengers from a zone-to-zone probability
    table and compute the passenger-vehicle detour matrix.

    Parameters
    ----------
    n : int
        # of passengers in large case
    m : int
        # of vehicles in large case
    OD_prob : 2-D table
        OD_prob[i][j] is the probability of origin zone i (0..16) and
        destination zone j (0..15)
    LA_loc : 2-D table
        zone coordinates; rows 0..16 are read as origin zones and rows
        17..32 as destination zones
    radius : float
        every sampled point is placed at this distance from its zone
        coordinate, in a uniformly drawn direction

    Returns
    -------
    OD_mat : 2*(n+m) x 2 array
        an array storing all the ODs (m vehicle Os, m vehicle Ds, n passenger
        Os, n passenger Ds)
        x | y
    det_info : n*m array
        det_info[i][j] = dist(O_i, O_j) + alpha_i + dist(D_i, D_j) - F_j for
        passenger i and vehicle j
    alpha_info : length-n vector
        a vector storing the alpha values (O-D distance) for passengers
    F : length-m vector
        O-D distance of every vehicle
    '''
    # OD_prob = np.loadtxt('LA_ODmat.txt')
    # LA_loc = np.loadtxt('LA_loc.txt')
    O_size = 17
    D_size = 16
    r_cluster = radius

    # one uniform draw per traveller, taken up front; the per-traveller angle
    # draws follow inside the loop (order kept so seeded runs are unchanged)
    probs = np.random.rand(n+m)
    O_mat = np.zeros((n+m,2))
    D_mat = np.zeros((n+m,2))
    for k in range(n+m):
        # inverse-CDF lookup: walk the table row by row until the running
        # remainder turns negative. If it never does (table mass not above
        # the draw) the indices stay at (0, 0).
        O_idx, D_idx = 0, 0
        prob = probs[k]
        for i, j in product(range(O_size), range(D_size)):
            prob -= OD_prob[i][j]
            if prob < 0:
                O_idx, D_idx = i, j
                break

        # independent directions for the origin (theta[0]) and the
        # destination (theta[1]) offsets
        theta = np.random.rand(2)*2*np.pi
        dx = r_cluster*np.cos(theta)
        dy = r_cluster*np.sin(theta)
        O_mat[k] = LA_loc[O_idx] + np.array([dx[0], dy[0]])
        D_mat[k] = LA_loc[D_idx + O_size] + np.array([dx[1], dy[1]])

    # rows [0, m) are vehicles, rows [m, m+n) are passengers
    alpha_all = np.linalg.norm(D_mat - O_mat, axis=1)
    F = alpha_all[:m]
    alpha_info = alpha_all[m:]

    # all passenger/vehicle pairs at once: (n, 1, 2) - (1, m, 2) -> (n, m, 2)
    to_pickup = np.linalg.norm(O_mat[m:, None, :] - O_mat[None, :m, :], axis=2)
    from_dropoff = np.linalg.norm(D_mat[m:, None, :] - D_mat[None, :m, :], axis=2)
    det_info = to_pickup + alpha_info[:, None] + from_dropoff - F[None, :]

    OD_mat = np.concatenate((O_mat[:m], D_mat[:m], O_mat[m:], D_mat[m:]))

    return OD_mat, det_info, alpha_info, F


def LA_data_new(n,m, OD_prob, LA_loc, cfg):
    '''
    Sample ODs for m vehicles and n passengers from a zone-to-zone probability
    table and build the Passenger / Vehicle instances.

    Parameters
    ----------
    n : int
        # of passengers in large case
    m : int
        # of vehicles in large case
    OD_prob : 2-D table
        OD_prob[i][j] is the probability of origin zone i (0..16) and
        destination zone j (0..15)
    LA_loc : 2-D table
        zone coordinates; rows 0..16 are read as origin zones and rows
        17..32 as destination zones
    cfg : Config object
        a configuration containing all parameters; cfg.mu_t_sys is read here
        and cfg is forwarded to Passenger / Vehicle

    Returns
    -------
    OD_mat : 2*(n+m) x 2 array
        an array storing all the ODs (m vehicle Os, m vehicle Ds, n passenger
        Os, n passenger Ds)
        x | y
    det_info : n*m array
        det_info[i][j] = dist(O_i, O_j) + alpha_i + dist(D_i, D_j) - F_j for
        passenger i and vehicle j
    int_time_info : length-(n+1) vector
        n exponential draws with scale cfg.mu_t_sys followed by a trailing
        1000 (presumably an end-of-stream sentinel - confirm with caller)
    passengers : list of Passenger
    vehicles : list of Vehicle
    '''
    # OD_prob = np.loadtxt('LA_ODmat.txt')
    # LA_loc = np.loadtxt('LA_loc.txt')
    O_size = 17
    D_size = 16
    r_cluster = 5   # fixed distance of every sampled point from its zone coordinate

    # one uniform draw per traveller, taken up front; the per-traveller angle
    # draws follow inside the loop (order kept so seeded runs are unchanged)
    probs = np.random.rand(n+m)
    O_mat = np.zeros((n+m,2))
    D_mat = np.zeros((n+m,2))
    for k in range(n+m):
        # inverse-CDF lookup: walk the table row by row until the running
        # remainder turns negative. If it never does (table mass not above
        # the draw) the indices stay at (0, 0).
        O_idx, D_idx = 0, 0
        prob = probs[k]
        for i, j in product(range(O_size), range(D_size)):
            prob -= OD_prob[i][j]
            if prob < 0:
                O_idx, D_idx = i, j
                break

        # independent directions for the origin (theta[0]) and the
        # destination (theta[1]) offsets
        theta = np.random.rand(2)*2*np.pi
        dx = r_cluster*np.cos(theta)
        dy = r_cluster*np.sin(theta)
        O_mat[k] = LA_loc[O_idx] + np.array([dx[0], dy[0]])
        D_mat[k] = LA_loc[D_idx + O_size] + np.array([dx[1], dy[1]])

    # rows [0, m) are vehicles, rows [m, m+n) are passengers
    alpha_all = np.linalg.norm(D_mat - O_mat, axis=1)
    F = alpha_all[:m]
    alpha_info = alpha_all[m:]

    int_time_info = np.random.exponential(scale=cfg.mu_t_sys, size=n)
    int_time_info = np.append(int_time_info, 1000)

    # passenger and vehicle instances
    passengers = [
        Passenger(i, alpha_info[i], O_mat[i+m], D_mat[i+m], int_time_info[i], cfg)
        for i in range(n)
    ]
    vehicles = [Vehicle(j, F[j], O_mat[j], D_mat[j], cfg) for j in range(m)]

    # all passenger/vehicle pairs at once: (n, 1, 2) - (1, m, 2) -> (n, m, 2)
    to_pickup = np.linalg.norm(O_mat[m:, None, :] - O_mat[None, :m, :], axis=2)
    from_dropoff = np.linalg.norm(D_mat[m:, None, :] - D_mat[None, :m, :], axis=2)
    det_info = to_pickup + alpha_info[:, None] + from_dropoff - F[None, :]

    OD_mat = np.concatenate((O_mat[:m], D_mat[:m], O_mat[m:], D_mat[m:]))

    return OD_mat, det_info, int_time_info, passengers, vehicles #consider removing some outputs later